"""Fine-tune, evaluate, and interactively query a Chinese text classifier
using the Hugging Face Transformers Trainer."""
import argparse
import os

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

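# Command-line interface: a positional dataset folder plus an optional save folder,
# base model name, and flags selecting which stages to run.
# Example invocation (script and data paths are illustrative):
#   python run_classifier.py data/my_corpus --do-train --do-eval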
parser = argparse.ArgumentParser()
parser.add_argument("dataset_folder", type=str)
parser.add_argument("--save_folder", type=str, default=None)
parser.add_argument(
    "--model_name", type=str, default="uer/albert-base-chinese-cluecorpussmall"
)
parser.add_argument("--do-train", action="store_true")
parser.add_argument("--do-eval", action="store_true")
parser.add_argument("--do-interactive", action="store_true")
args = parser.parse_args()

if args.save_folder is None:
    args.save_folder = os.path.join(args.dataset_folder, "output")

# Read the label inventory (one label per line) and build label <-> id mappings.
all_labels = open(
    os.path.join(args.dataset_folder, "label.txt"), encoding="utf-8"
).read().splitlines()
label2id = {label: i for i, label in enumerate(all_labels)}
id2label = {i: label for i, label in enumerate(all_labels)}


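# Tokenizer for the base model; inputs are capped at 512 tokens.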
tokenizer = AutoTokenizer.from_pretrained(args.model_name, model_max_length=512)

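# Accuracy metric loaded from a local accuracy.py script rather than the Hub.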
accuracy = evaluate.load("./accuracy.py")


def preprocess_function(examples):
    """Tokenize the texts and convert string labels to integer ids."""
    # Padding is left to DataCollatorWithPadding, which pads each batch dynamically.
    model_inputs = tokenizer(examples["text"], truncation=True, max_length=512)
    model_inputs["label"] = [label2id[label] for label in examples["label"]]
    return model_inputs


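# Each split is a tab-separated file of "text<TAB>label" rows inside dataset_folder.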
def load_split(name):
    # load_dataset("csv") places a single data file under the "train" split key,
    # hence the trailing ["train"].
    return load_dataset(
        "csv",
        data_files=os.path.join(args.dataset_folder, f"{name}.txt"),
        delimiter="\t",
        column_names=["text", "label"],
        cache_dir=os.path.join(args.dataset_folder, "cache"),
    ).map(preprocess_function, batched=True)["train"]


train_dataset = load_split("train")
dev_dataset = load_split("dev")
test_dataset = load_split("test")


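# Pad each batch dynamically to the length of its longest sequence.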
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


def compute_metrics(eval_pred):
    """Compute accuracy from the argmax of the logits."""
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)


def do_train():
    """Fine-tune the base model and save the best checkpoint to args.save_folder."""
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=len(all_labels),
        id2label=id2label,
        label2id=label2id,
    )
    # Training hyperparameters: bf16 mixed precision (requires supported hardware),
    # per-epoch evaluation and checkpointing, and logging to Weights & Biases.
    training_args = TrainingArguments(
        output_dir=os.path.join(args.dataset_folder, "output"),
        learning_rate=1e-5,
        per_device_train_batch_size=48,
        per_device_eval_batch_size=48,
        num_train_epochs=30,
        weight_decay=0.01,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        warmup_steps=100,
        save_total_limit=2,
        bf16=True,
        report_to="wandb",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        # Stop early if the monitored metric stalls for 5 consecutive evaluations.
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
    )

    trainer.train()
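    # To resume an interrupted run, pass resume_from_checkpoint instead, e.g.: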
    # trainer.train(
    #     resume_from_checkpoint=os.path.join(
    #         args.dataset_folder, "output/checkpoint-7700"
    #     ),
    # )
    trainer.save_model(args.save_folder)


def do_eval():
    """Report dev/test accuracy and print misclassified test examples."""
    model = AutoModelForSequenceClassification.from_pretrained(args.save_folder)
    # A lightweight Trainer used only for batched prediction.
    training_args = TrainingArguments(
        output_dir=os.path.join(args.dataset_folder, "output"),
        per_device_eval_batch_size=128,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    predictions = trainer.predict(dev_dataset)
    print("Dev set results:")
    print(predictions.metrics)
    predictions = trainer.predict(test_dataset)
    print("Test set results:")
    print(predictions.metrics)

    # Print each misclassified test example: predicted label, gold label, and
    # length of the decoded text.
    for i in range(len(test_dataset)):
        if (
            id2label[predictions.predictions[i].argmax()]
            != id2label[test_dataset[i]["label"]]
        ):
            print(
                id2label[predictions.predictions[i].argmax()],
                id2label[test_dataset[i]["label"]],
                len(
                    tokenizer.decode(
                        test_dataset[i]["input_ids"], skip_special_tokens=True
                    )
                ),
            )


def do_interactive():
    """Classify user-typed sentences in a simple REPL loop."""
    model = AutoModelForSequenceClassification.from_pretrained(args.save_folder)
    while True:
        text = input("Input text: ")
        # Truncate to the model's 512-token limit and skip gradient tracking.
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = model(**inputs)
        print(id2label[outputs.logits[0].argmax(-1).item()])


if __name__ == "__main__":
    if args.do_train:
        do_train()
    if args.do_eval:
        do_eval()
    if args.do_interactive:
        do_interactive()
